package drds.plus.sql_process.parser.visitor.dml;

import drds.plus.parser.abstract_syntax_tree.expression.Expression;
import drds.plus.parser.abstract_syntax_tree.expression.Pair;
import drds.plus.parser.abstract_syntax_tree.expression.primary.misc.Identifier;
import drds.plus.parser.abstract_syntax_tree.expression.primary.misc.RowValues;
import drds.plus.parser.abstract_syntax_tree.statement.InsertStatement;
import drds.plus.parser.abstract_syntax_tree.statement.Query;
import drds.plus.parser.visitor.EmptyVisitor;
import drds.plus.sql_process.abstract_syntax_tree.expression.item.Item;
import drds.plus.sql_process.abstract_syntax_tree.node.dml.Insert;
import drds.plus.sql_process.abstract_syntax_tree.node.query.TableQuery;
import drds.plus.sql_process.parser.visitor.ExpressionVisitor;
import drds.plus.sql_process.utils.OptimizerUtils;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Converts a parsed {@link InsertStatement} into an optimizer {@link Insert} node.
 *
 * <p>Three statement shapes are handled:
 * <ul>
 *   <li>a single VALUES row,</li>
 *   <li>several VALUES rows, and</li>
 *   <li>an INSERT ... SELECT.</li>
 * </ul>
 * The statement's duplicate-update assignments (presumably ON DUPLICATE KEY UPDATE) are attached
 * afterwards. Call {@link #visit(InsertStatement)} and then read the result with
 * {@link #getInsertNode()}.
 *
 * <p>The result is kept in a mutable field, so presumably one instance is used per statement —
 * confirm with callers.
 */
public class InsertVisitor extends EmptyVisitor {

    /** Node built by the last successful {@link #visit(InsertStatement)}; null before any visit. */
    private Insert insertNode;

    /**
     * Builds the {@link Insert} node for {@code insertStatement} and stores it for
     * {@link #getInsertNode()}.
     *
     * <p>A null or empty VALUES list is treated as an INSERT ... SELECT.
     *
     * @param insertStatement the parsed INSERT statement
     * @throws IllegalStateException if the statement has neither VALUES rows nor a SELECT query
     */
    public void visit(InsertStatement insertStatement) {
        TableQuery tableQueryNode = getTableQueryNode(insertStatement);
        List<RowValues> rowDataList = insertStatement.getRowValuesList();

        // An empty list is handled like null so an accompanying SELECT is not silently ignored.
        if (rowDataList != null && !rowDataList.isEmpty()) {
            this.insertNode = buildValuesInsert(tableQueryNode, insertStatement, rowDataList);
        } else {
            this.insertNode = buildSelectInsert(tableQueryNode, insertStatement);
        }

        applyDuplicateUpdate(insertStatement.getDuplicateUpdate());
    }

    /** Builds an INSERT ... VALUES node from one or more (non-empty) value rows. */
    private Insert buildValuesInsert(TableQuery tableQueryNode, InsertStatement insertStatement,
                                     List<RowValues> rowDataList) {
        String columnNameList = getColumnNameList(insertStatement);
        if (rowDataList.size() == 1) {
            return tableQueryNode.insert(columnNameList, getRowValues(rowDataList.get(0)));
        }
        List<List<Object>> rowValuesList = new ArrayList<List<Object>>(rowDataList.size());
        for (RowValues rowValues : rowDataList) {
            rowValuesList.add(Arrays.asList(getRowValues(rowValues)));
        }
        return tableQueryNode.insert(columnNameList, rowValuesList);
    }

    /**
     * Builds an INSERT ... SELECT node.
     *
     * <p>Original note (translated): sub-table queries are not supported yet.
     *
     * @throws IllegalStateException if the statement carries no SELECT query
     */
    private Insert buildSelectInsert(TableQuery tableQueryNode, InsertStatement insertStatement) {
        Query query = insertStatement.getQuery();
        if (query == null) {
            // Message (Chinese): "if the VALUES part has no values, the insert must contain a select".
            throw new IllegalStateException("insert语句values部分如果没有值则必须包含select");
        }
        Insert insert = new Insert(tableQueryNode);
        insert.setColumnNameList(getColumnItems(insertStatement));
        SelectVisitor selectVisitor = new SelectVisitor();
        query.accept(selectVisitor);
        insert.setQuery(selectVisitor.getQuery());
        return insert;
    }

    /**
     * Attaches the duplicate-update assignments, if any, to {@link #insertNode}.
     *
     * @param duplicateUpdate column/expression pairs; null or empty means nothing to attach
     */
    private void applyDuplicateUpdate(List<Pair<Identifier, Expression>> duplicateUpdate) {
        if (duplicateUpdate == null || duplicateUpdate.isEmpty()) {
            return;
        }
        int size = duplicateUpdate.size();
        String[] updateColumns = new String[size];
        // Object[] rather than Comparable[]: a value may be a function object (or any other
        // non-Comparable result), which would make a Comparable[] throw ArrayStoreException.
        Object[] updateValues = new Object[size];
        for (int i = 0; i < size; i++) {
            Pair<Identifier, Expression> pair = duplicateUpdate.get(i);
            updateColumns[i] = pair.getKey().getText();
            updateValues[i] = evaluate(pair.getValue());
        }
        this.insertNode.duplicateUpdate(updateColumns, updateValues);
    }

    /** Runs a fresh {@link ExpressionVisitor} over {@code expression} and returns its result. */
    private Object evaluate(Expression expression) {
        ExpressionVisitor expressionVisitor = new ExpressionVisitor();
        expression.accept(expressionVisitor);
        return expressionVisitor.getObject();
    }

    /** Creates the target-table node named by the statement. */
    private TableQuery getTableQueryNode(InsertStatement insertStatement) {
        return new TableQuery(insertStatement.getTableName().getText());
    }

    /**
     * Joins the explicit column names with single spaces, the format passed to
     * {@code TableQuery.insert}.
     *
     * @return the joined names, or {@code ""} when no column list was given
     */
    private String getColumnNameList(InsertStatement insertStatement) {
        List<Identifier> identifiers = insertStatement.getColumnNameList();
        StringBuilder sb = new StringBuilder();
        if (identifiers != null) {
            String separator = "";
            for (Identifier identifier : identifiers) {
                sb.append(separator).append(identifier.getText());
                separator = " ";
            }
        }
        return sb.toString();
    }

    /**
     * Builds one column item per explicit column identifier.
     *
     * <p>The items are created directly from the identifiers instead of re-splitting a joined
     * string. An absent column list therefore yields an empty list rather than a single
     * empty-named column, and identifier text containing spaces stays intact.
     */
    private List<Item> getColumnItems(InsertStatement insertStatement) {
        List<Identifier> identifiers = insertStatement.getColumnNameList();
        List<Item> itemList = new ArrayList<Item>();
        if (identifiers != null) {
            for (Identifier identifier : identifiers) {
                itemList.add(OptimizerUtils.createColumnFromString(identifier.getText()));
            }
        }
        return itemList;
    }

    /** Evaluates every expression of one VALUES row, preserving column order. */
    private Object[] getRowValues(RowValues rowValues) {
        int size = rowValues.getRowValueList().size();
        Object[] values = new Object[size];
        for (int i = 0; i < size; i++) {
            ExpressionVisitor expressionVisitor = new ExpressionVisitor();
            rowValues.getRowValueList().get(i).accept(expressionVisitor);
            values[i] = expressionVisitor.getObject();
        }
        return values;
    }

    /** Returns the node built by the last {@link #visit(InsertStatement)}, or null if none. */
    public Insert getInsertNode() {
        return insertNode;
    }
}
